
[PaddlePaddle Hackathon 2nd No.16] add API RReLU #42466

Conversation

@OccupyMars2025 (Contributor) commented May 4, 2022

PR types

New features

PR changes

APIs

Describe

add RReLU activation function for Paddle #40317
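
For context: RReLU (randomized leaky ReLU, from Xu et al., 2015, "Empirical Evaluation of Rectified Activations in Convolutional Network") leaves non-negative inputs unchanged and multiplies negative inputs by a random slope a sampled from U(lower, upper) during training, i.e. out = x if x >= 0, else a * x; at inference the fixed slope (lower + upper) / 2 is used.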

paddle-bot-old (bot) commented May 4, 2022

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@dingjiaweiww (Contributor) commented

Please make sure the CI passes first.

paddle-bot-old (bot) commented May 5, 2022

The format check passed. Your PR will be reviewed by Paddle experts and developers from the open-source community; please stay tuned.

@zhiboniu (Contributor) commented May 5, 2022

Please switch the random-number generation to the wrapped helper functions (see the existing call sites for details):

funcs::uniform_distribution<T> dist;
funcs::uniform_real_transform<T> trans(min, max);
funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
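
As a concrete illustration of this suggestion, here is a minimal self-contained sketch of a noise-filling helper; FillRReluNoise and the lower / upper parameter names are hypothetical, not identifiers from this PR:

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"

// Hypothetical helper: fill `noise` with per-element slopes drawn from
// U(lower, upper) via the wrapped distribution utilities, replacing any
// hand-rolled RNG in the kernel.
template <typename T>
void FillRReluNoise(const phi::GPUContext& dev_ctx,
                    float lower,
                    float upper,
                    phi::DenseTensor* noise) {
  dev_ctx.template Alloc<T>(noise);
  phi::funcs::uniform_distribution<T> dist;
  phi::funcs::uniform_real_transform<T> trans(lower, upper);
  phi::funcs::distribution_and_transform<T>(dev_ctx, noise, dist, trans);
}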

@zhwesky2010 (Contributor) left a comment

1. The fix_seed parameter should not be needed, and neither should seed. The post-2.0 convention is to obtain the seed directly from the generator instead of giving the OP a separate seed attribute:

auto gen_cuda = ctx.GetGenerator();
// seed_offset is obtained from the generator; the increment size depends
// on how many random numbers each thread consumes.
auto seed_offset = gen_cuda->IncrementOffset(1);
uint64_t seed = seed_offset.first;
uint64_t offset = seed_offset.second;

2. For the GPU kernel, you can follow the pattern in https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/exponential_op.cu#L29-L31 to obtain the uniform distribution:

using MT = typename kps::details::MPTypeTrait<T>::Type;
funcs::uniform_distribution<MT> dist;
funcs::uniform_real_transform<MT> trans(min, max);
funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);

There is no need to vectorize with vec_size yourself; just launch the kernel directly. See how distribution_and_transform is wrapped in https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/funcs/distribution_helper.h

3. Inside phi, don't use framework::Tensor any more; use DenseTensor throughout. The is_fix_seed, seed_val, and const framework::Tensor* seed parameters should likewise be unnecessary. Don't model the code too closely on dropout, which still carries some 1.x-era legacy kept for backward compatibility; ctx.GetGenerator() now provides the seed directly, similar to https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/gpu/poisson_kernel.cu
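
Taken together, the three points suggest a kernel shaped roughly like the following. This is a minimal sketch loosely modeled on poisson_kernel.cu; the signature and the lower / upper / is_test / noise names are illustrative assumptions, not the API as merged:

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"

namespace phi {

template <typename T, typename Context>
void RReluKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 float lower,
                 float upper,
                 bool is_test,
                 DenseTensor* out,
                 DenseTensor* noise) {
  dev_ctx.template Alloc<T>(out);
  dev_ctx.template Alloc<T>(noise);

  if (!is_test) {
    // Training: per-element negative slopes drawn from U(lower, upper).
    // The helper obtains its seed/offset from the device generator
    // internally, so the OP needs no fix_seed/seed attributes.
    funcs::uniform_distribution<T> dist;
    funcs::uniform_real_transform<T> trans(lower, upper);
    funcs::distribution_and_transform<T>(dev_ctx, noise, dist, trans);
  } else {
    // Inference: a fixed slope of (lower + upper) / 2 would be used
    // instead (fill omitted in this sketch).
  }

  // Remaining step (omitted): an elementwise kernel computing
  // out[i] = x[i] >= 0 ? x[i] : x[i] * noise[i].
}

}  // namespace phi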

@OccupyMars2025 (Contributor, Author) commented May 5, 2022

Got it, I'll fix it right away.

@OccupyMars2025 (Contributor, Author) commented

done

@OccupyMars2025 (Contributor, Author) commented May 6, 2022

TODO: fix a bug: if I uncomment the test cases that I commented out in python/paddle/fluid/tests/unittests/test_rrelu_op.py, I get the error "'X' contains uninitialized tensor".

@OccupyMars2025 (Contributor, Author) commented

You may need to add paddle.enable_static() in a few places in python/paddle/fluid/tests/unittests/test_rrelu_op.py.

"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
// AddAttr<bool>("fix_seed",
A reviewer (Contributor) commented on the diff hunk above:

All of these can be removed now.
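
For illustration, the trimmed attribute block might then keep only the uniform-distribution bounds. The lower / upper names and the 1/8 and 1/3 defaults below follow the common RReLU convention and are assumptions, not necessarily this PR's final code:

// Sketch of the OpMaker attributes with the seed-related ones dropped;
// only the uniform-distribution bounds remain.
AddAttr<float>("lower", "Lower bound of the uniformly sampled negative slope.")
    .SetDefault(1.0f / 8.0f);
AddAttr<float>("upper", "Upper bound of the uniformly sampled negative slope.")
    .SetDefault(1.0f / 3.0f);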

@OccupyMars2025 (Contributor, Author) commented May 16, 2022

Thanks to the Paddle experts for the guidance. I have finished the main development following the review comments; only the unit tests still have issues. Since someone else has already had a submission for this task accepted, I will pause work on this PR (when I have time later, I will look into the unit-test problem).

paddle-bot-old (bot) commented

Sorry to inform you that, after our discussions, your PR does not meet the merging standard (reference: Paddle Custom Operator Design Doc). You can submit a new PR; we are closing this one. Thank you for your contribution.
